"""
The Refer-Judge Method
"""

import sys
import multiprocessing
import json
from tqdm import  tqdm
import os

from SYSTEM_MESSAGE import (SYSTEM_TEXT_SENTENCE_PARADOX, SYSTEM_TEXT_SENTENCE_GROUP_1, SYSTEM_TEXT_SENTENCE_GROUP_2,
                            SYSTEM_TEXT_IMAGE_EXIST, SYSTEM_TEXT_IMAGE_UNIQUENESS,
                            SYSTEM_TEXT_REFER_REFINEMENT_POS, SYSTEM_TEXT_REFER_REFINEMENT_POS_UNIQUE,
                            SYSTEM_TEXT_REFER_REFINEMENT_MIS, SYSTEM_TEXT_REFER_JUDGE)

from ref_jobs import item_filter, item_uniqueness_filter

from config import batch_size, JOB_MARK_List, model_name, org_data, save_name, log_file, record_token_mark, dirt_name
model = model_name  # short alias used when building JSON keys and output file names

# Prepare the output directory and make sure a log file exists for the jobs to append to.
os.makedirs(dirt_name, exist_ok=True)
log_missing = not os.path.exists(log_file)
if log_missing:
    open(log_file, 'w').close()  # create an empty log file
# ===============================

# ---- Per-job configuration, indexed by JOB_MARK ----
LLM_JOBS = ['pos_frame_exist', 'pos_frame_unique', 'mis_frame_exist', 'sentence_logical', 'group_logical',
            'ref_refine_exist_pos', 'ref_refine_unique_pos', 'ref_refine_exist_mis', 'refer_judge']  # 0-4 / 5-8
FRAME_NUMBER = [8, 8, 12, 0, 0, 0, 0, 0, 0]
TOKEN_NUMBER = [1500] * len(LLM_JOBS)
JSON_KEYS = [f'{model}_positive_image_existence', f'{model}_positive_uniqueness',
             f'{model}_misleading_image_existence', f'{model}_sentence_logical', f'{model}_group_logical',
             f'{model}_refine_exist_pos', f'{model}_refine_unique_pos', f'{model}_refine_exist_mis',
             f'{model}_final_score']
# The group_logical job (index 4) takes a pair of system messages.
MESSAGES = [SYSTEM_TEXT_IMAGE_EXIST, SYSTEM_TEXT_IMAGE_UNIQUENESS, SYSTEM_TEXT_IMAGE_EXIST,
            SYSTEM_TEXT_SENTENCE_PARADOX, [SYSTEM_TEXT_SENTENCE_GROUP_1, SYSTEM_TEXT_SENTENCE_GROUP_2],
            SYSTEM_TEXT_REFER_REFINEMENT_POS, SYSTEM_TEXT_REFER_REFINEMENT_POS_UNIQUE,
            SYSTEM_TEXT_REFER_REFINEMENT_MIS, SYSTEM_TEXT_REFER_JUDGE]
# Jobs that read 'misleading_frames' instead of 'positive_frames' (mis_frame_exist, ref_refine_exist_mis).
MISLEADING_JOB_MARKS = (2, 7)
# NOTE(review): job 0 only processes the first 10 items of org_data, which looks like a
# debugging cap. Set to None to process every item.
FIRST_JOB_ITEM_LIMIT = 10


def _select_job_fn(llm_job):
    """Return the ``ref_jobs`` function that implements the job named *llm_job*.

    The checks are substring matches evaluated in order; a name matching none of
    them falls through to ``sentence_judge`` (the 'sentence_logical' job).
    """
    if 'frame_exist' in llm_job:
        from ref_jobs import image_exist as job_fn
    elif 'frame_unique' in llm_job:
        from ref_jobs import image_uniqueness as job_fn
    elif 'ref_refine' in llm_job:
        from ref_jobs import refer_refine as job_fn
    elif 'refer_judge' in llm_job:
        from ref_jobs import refer_judge as job_fn
    elif 'group_logical' in llm_job:
        from ref_jobs import sentence_group_judge as job_fn
    else:
        from ref_jobs import sentence_judge as job_fn
    return job_fn


def safe_job_iter(llm_job, data_item, system_message, json_key, pos_mis, frame_num, token_num, lock):
    """Process one item inside a pool worker without letting its errors escape.

    Defined at module level and given the job *name* rather than the job function,
    so workers resolve the function themselves instead of relying on globals
    inherited from a forked parent (required for the spawn/forkserver start methods).

    Returns:
        The job function's result, or the unmodified *data_item* if the job raised.
    """
    job_fn = _select_job_fn(llm_job)
    try:
        return job_fn(data_item, system_message, json_key, pos_mis, frame_num, token_num, lock=lock)
    except Exception as e:  # worker boundary: one failing item must not abort the whole batch
        item_id = data_item.get('item_id') if isinstance(data_item, dict) else None
        print(f"Error processing item {item_id}: {e!r}", file=sys.stderr)
        return data_item


def main():
    """Run every job listed in ``JOB_MARK_List``, in order.

    Job 0 reads ``org_data``; job ``k > 0`` reads the JSON file written by job
    ``k - 1``. Each job's items are filtered, processed ``batch_size`` at a time in
    a process pool, and the results are written to ``save_name.format(model, k)``.

    Raises:
        ValueError: if a JOB_MARK is outside ``range(len(LLM_JOBS))``.
    """
    for job_mark in tqdm(JOB_MARK_List):
        # Reject out-of-range marks explicitly; a negative index would otherwise
        # silently pick a job from the end of the config lists.
        if not 0 <= job_mark < len(LLM_JOBS):
            raise ValueError(f'JOB_MARK must be in range(0, {len(LLM_JOBS)}), got {job_mark!r}')

        llm_job, json_key, system_message = LLM_JOBS[job_mark], JSON_KEYS[job_mark], MESSAGES[job_mark]
        frame_num, token_num = FRAME_NUMBER[job_mark], TOKEN_NUMBER[job_mark]
        pos_mis = 'misleading_frames' if job_mark in MISLEADING_JOB_MARKS else 'positive_frames'

        # ---- load input ----
        if job_mark == 0:
            job_data = org_data[:FIRST_JOB_ITEM_LIMIT]
        else:
            with open(save_name.format(model, str(job_mark - 1)), encoding='utf-8') as f:
                job_data = json.load(f)

        # ---- data clean ----
        # The uniqueness job has its own filter. Compare against the JSON key:
        # job names in LLM_JOBS never carry the model prefix.
        if json_key == f'{model}_positive_uniqueness':
            filtered_data = item_uniqueness_filter(job_data, json_key)
        else:
            # NOTE(review): the second value returned by item_filter is discarded, so
            # those items are not written to this job's output file — confirm intended.
            filtered_data, _ = item_filter(job_data, json_key)

        # Fail fast in the parent if the job function cannot be imported.
        _select_job_fn(llm_job)

        with open(log_file, 'a') as f:
            f.write(f'start {llm_job} job\n')

        # ---- multiprocessing ----
        results = []  # only the parent extends this, so a plain list suffices
        with multiprocessing.Manager() as manager:
            # A Manager lock proxy can be passed to pool workers as an argument
            # (a plain multiprocessing.Lock cannot); None when record_token_mark is off.
            lock = manager.Lock() if record_token_mark else None
            with multiprocessing.Pool(processes=batch_size) as pool:
                for start in tqdm(range(0, len(filtered_data), batch_size), leave=False):
                    batch_args = [
                        (llm_job, item, system_message, json_key, pos_mis, frame_num, token_num, lock)
                        for item in filtered_data[start:start + batch_size]
                    ]
                    batch_results = pool.starmap(safe_job_iter, batch_args)
                    results.extend(r for r in batch_results if r is not None)

        with open(save_name.format(model, str(job_mark)), 'w', encoding='utf-8') as f:
            json.dump(results, f)

        print(f"{llm_job} processing complete.")


if __name__ == '__main__':
    # Required by multiprocessing: the spawn/forkserver start methods re-import the
    # main module in every worker, which must not re-run the job loop.
    main()

# =============================